In [1]:
import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from dotenv import load_dotenv
from pandas.plotting import parallel_coordinates
import importlib
import plotly.express as px
import plotly.graph_objects as go
import os
import glob
import pandas as pd
import json

import utils.db_tools as db_tools

# Reload BEFORE the `from ... import` below. importlib.reload() re-executes the
# module and rebinds names on the module object, but names that were copied out
# earlier with `from utils.db_tools import ...` would keep pointing at the old
# function objects. Reloading first means re-running this cell picks up edits
# to db_tools.py for both `db_tools.<name>` and the bare names imported here.
importlib.reload(db_tools)

from utils.db_tools import (
    get_db,
    filter_df,
    make_animation,
    get_data,
    metrics_grid,
    plot_grid,
    compute_metrics,
)

from classify import classify_trajectories
Out[1]:
<module 'utils.db_tools' from '/cluster/home/vogtva/pde-solvers-cuda/analysis/utils/db_tools.py'>
In [2]:
# Which simulation run to analyse.
model = "bruss"
run_id = "ball_big"

# Paths come from the environment (optionally populated from a .env file).
load_dotenv()
data_dir = os.getenv("DATA_DIR")
output_dir = os.getenv("OUT_DIR")
# Fail fast with an actionable message: an unset OUT_DIR would otherwise turn
# into the literal path "None/<model>/..." and a confusing FileNotFoundError.
if not output_dir:
    raise RuntimeError("OUT_DIR is not set; define it in the environment or in a .env file")

metrics_path = os.path.join(output_dir, model, run_id, "classification_metrics_02.csv")
df = pd.read_csv(metrics_path)
df_class = classify_trajectories(df)
df = df_class.copy()
# Optional: drop rows whose simulation output file no longer exists on disk.
# df = df[df["filename"].apply(os.path.exists)].reset_index(drop=True)
# String key used for grouping by sampling centre.
df["op"] = df["original_point"].astype(str)
In [4]:
# Number of classified trajectories per category, most frequent first.
df.loc[:, "category"].value_counts(sort=True, ascending=False)
Out[4]:
category
SS     3007
OSC    2128
I      1676
DSS     387
Name: count, dtype: int64
In [9]:
# Stacked histogram of 'mean_deviation', one colour per trajectory category.
ax = plt.figure(figsize=(12, 8)).add_subplot()
sns.histplot(
    data=df_class,
    x='mean_deviation',
    hue='category',
    multiple='stack',
    kde=False,
    ax=ax,
)
ax.set_xlabel('Mean Deviation')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Mean Deviation by Category')
plt.show()
No description has been provided for this image
In [10]:
# Interactive scatter of the A/B parameters of every trajectory,
# coloured by the category assigned by classify_trajectories.
fig = px.scatter(
    df_class,
    x="A",
    y="B",
    color="category",
    labels={"A": "A", "B": "B"},
    title="Scatter plot of A vs B",
    height=800,
    width=800,
)
fig.show()
In [15]:
# Per-category row counts of df via DataFrame.value_counts on the 'category' column.
df.value_counts(subset="category", sort=True, ascending=False)
Out[15]:
category
SS     3957
OSC    1530
I      1294
DSS     419
Name: count, dtype: int64
In [14]:
def plot_ball_behavior(df, metric="dev"):
    t = np.linspace(0, 100, 100)
    title = ""

    all_metrics = []
    for _, row in df.iterrows():
        d = get_data(row)
        metrics = compute_metrics(row, start_frame=0)
        if metric == "dev":
            title = "Deviation"
            values = metrics[0]
        elif metric == "dt":
            title = "Time Derivative"
            values = metrics[1]
        elif metric == "dx":
            title = "Spatial Derivative"
            values = metrics[2]
        all_metrics.append(values)
    
    # Convert to numpy array for easier computation
    all_metrics = np.array(all_metrics)

    # Compute mean and std
    avg_metric = np.mean(all_metrics, axis=0)
    min_metric = np.min(all_metrics, axis=0)
    std_metric = np.std(all_metrics, axis=0)

    # Create figure
    fig = go.Figure()

    # Add shaded area for standard deviation
    # fig.add_trace(
    #     go.Scatter(
    #         x=np.concatenate([t, t[::-1]]),
    #         y=np.concatenate(
    #             [avg_metric + std_metric, (avg_metric)[::-1]]
    #         ),
    #         fill="toself",
    #         fillcolor="rgba(0,100,80,0.2)",
    #         line=dict(color="rgba(255,255,255,0)"),
    #         showlegend=False,
    #     )
    # )

    # # Add mean line
    # fig.add_trace(
    #     go.Scatter(
    #         x=t,
    #         y=avg_metric,
    #         mode="lines",
    #         name=title,
    #         hovertemplate="Index: %{x}<br>Deviation: %{y:.2f}<extra></extra>",
    #     )
    # )

    # fig.add_trace(
    #      go.Scatter(
    #         x=t,
    #         y=min_metric,
    #         mode="lines",
    #         name="min",
    #         hovertemplate="Index: %{x}<br>Min: %{y:.2f}<extra></extra>",
    #     )
    # )

    fig.add_trace(
            go.Scatter(
                x=t,
                y=values,
                mode="lines",
                name=f"Row {row.name}",  # Use row index or a unique identifier
                hovertemplate="Index: %{x}<br>Value: %{y:.2f}<extra></extra>",
            )
        )
        
    # Update layout
    fig.update_layout(
        title="Deviation Metrics",
        xaxis_title="Time Step/Index",
        yaxis_title="Deviation Value",
        hovermode="x unified",
        showlegend=True,
        template="plotly_white",
    )

    fig.show()
In [15]:
# Give each sampling centre a string key (groupby needs hashable values), then,
# for each group of trajectories sharing that centre, print its category
# counts and plot the deviation series of every member.
df_class["op"] = df_class["original_point"].astype(str)
for _, df1 in df_class.groupby("op"):
    original_point = df1["original_point"].iloc[0]
    print(original_point, df1.value_counts("category").to_dict())
    plot_ball_behavior(df1)
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 11} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 18} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 11} {'SS': 59, 'I': 1}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 18} {'SS': 59, 'I': 1}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 11} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 18} {'OSC': 49, 'I': 11}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 4} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 12} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 33} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 54} {'OSC': 55, 'I': 5}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 11} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 18} {'OSC': 51, 'I': 9}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 4} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 12} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 33} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 54} {'OSC': 59, 'I': 1}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 11} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 18} {'SS': 57, 'DSS': 3}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 11} {'I': 30, 'OSC': 13, 'SS': 12, 'DSS': 5}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 18} {'I': 30, 'OSC': 21, 'SS': 8, 'DSS': 1}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 4} {'OSC': 30, 'SS': 25, 'I': 5}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 12} {'OSC': 25, 'SS': 25, 'I': 10}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 33} {'I': 33, 'OSC': 16, 'SS': 11}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 54} {'I': 33, 'OSC': 14, 'SS': 13}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 11} {'OSC': 32, 'I': 27, 'DSS': 1}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'I': 23, 'DSS': 6}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 4} {'OSC': 51, 'I': 8, 'DSS': 1}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 12} {'OSC': 56, 'I': 4}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 33} {'OSC': 43, 'DSS': 9, 'I': 8}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 21, 'DSS': 9}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 19, 'I': 11}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 4} {'OSC': 46, 'I': 14}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 12} {'OSC': 57, 'I': 3}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 33} {'OSC': 34, 'I': 26}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 54} {'OSC': 30, 'I': 29, 'DSS': 1}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 11} {'SS': 56, 'I': 4}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 18} {'SS': 40, 'I': 16, 'DSS': 4}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 54} {'SS': 47, 'I': 13}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 11} {'I': 29, 'SS': 25, 'OSC': 4, 'DSS': 2}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 30}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 4} {'SS': 38, 'OSC': 13, 'I': 7, 'DSS': 2}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 12} {'SS': 36, 'I': 15, 'OSC': 9}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 33} {'I': 34, 'SS': 20, 'OSC': 6}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 22, 'OSC': 8}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 29, 'DSS': 1}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 18} {'OSC': 30, 'I': 26, 'DSS': 4}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 4} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'DSS': 20, 'I': 10}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 24, 'I': 6}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 4} {'OSC': 43, 'I': 11, 'DSS': 6}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 11} {'SS': 45, 'I': 13, 'DSS': 2}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 24, 'DSS': 6}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 33} {'SS': 50, 'I': 10}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 54} {'SS': 30, 'I': 28, 'DSS': 2}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 11} {'I': 30, 'SS': 30}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 28, 'OSC': 2}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 4} {'SS': 39, 'I': 20, 'DSS': 1}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 12} {'SS': 42, 'I': 17, 'DSS': 1}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 33} {'I': 31, 'SS': 29}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 17, 'DSS': 13}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 19, 'I': 11}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 4} {'OSC': 33, 'I': 22, 'DSS': 3, 'SS': 2}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 12} {'OSC': 36, 'I': 24}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 28, 'SS': 2}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'DSS': 20, 'I': 10}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 28, 'I': 2}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 4} {'OSC': 32, 'DSS': 17, 'I': 10}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 11} {'SS': 30, 'I': 29, 'DSS': 1}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 25, 'DSS': 5}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 12} {'SS': 59, 'I': 1}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 33} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 11} {'SS': 30, 'DSS': 18, 'I': 12}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'DSS': 23, 'I': 7}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 4} {'SS': 33, 'I': 17, 'DSS': 10}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 12} {'SS': 30, 'I': 29, 'DSS': 1}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 33} {'SS': 30, 'I': 26, 'DSS': 4}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 54} {'SS': 30, 'I': 24, 'DSS': 6}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 11} {'DSS': 29, 'SS': 28, 'OSC': 2, 'I': 1}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 19, 'DSS': 11}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 4} {'SS': 30, 'DSS': 25, 'I': 5}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 12} {'SS': 30, 'I': 26, 'DSS': 4}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 33} {'SS': 30, 'DSS': 18, 'I': 12}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 11} {'SS': 46, 'I': 14}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 33} {'SS': 41, 'I': 17, 'DSS': 2}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
In [ ]: